{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Signal echoing\n",
    "\n",
    "Echoing a signal by `n` steps is an example of a synchronized many-to-many task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from res.sequential_tasks import EchoData\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "\n",
    "# Fix the global PyTorch and NumPy random seeds so that runs are repeatable.\n",
    "torch.manual_seed(1)\n",
    "np.random.seed(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters of the echo task.\n",
    "batch_size = 5\n",
    "echo_step = 3\n",
    "series_length = 20_000\n",
    "BPTT_T = 20  # handed to `EchoData` as `truncated_length`\n",
    "\n",
    "# `EchoData` supplies the inputs and targets used to train a network to echo a\n",
    "# `series_length`-long stream of data:\n",
    "#   * `.x_batch` holds the input series, with shape `[batch_size, series_length]`;\n",
    "#   * `.y_batch` holds the target data and has the same shape as `.x_batch`.\n",
    "#\n",
    "# In contrast to the other training data in this course, successive batches taken\n",
    "# from a single `EchoData` object are cut from the same stream. In\n",
    "# 08-seq_classification, for example, the training data is laid out as\n",
    "#\n",
    "#   [[S11 S12...S1N], [S21 S22...S2N], ..., [SM1 SM2...SMN]]\n",
    "#\n",
    "# where `SIJ` denotes the `J`th sample drawn from the `I`th stream.\n",
    "#\n",
    "# `EchoData` output, on the other hand, is laid out as follows (looking at a single\n",
    "# row of the batch dimension, here the first stream):\n",
    "#\n",
    "#   [[S11 S12...S1N], [S1(N+1) S1(N+2)...S1(2N)], ..., [S1((M-1)N+1)...S1(MN)]]\n",
    "#\n",
    "# Each batch continues where the previous one stopped, so successive batches drawn\n",
    "# from the same `EchoData` object are not independent.\n",
    "echo_data_kwargs = dict(\n",
    "    echo_step=echo_step,\n",
    "    batch_size=batch_size,\n",
    "    series_length=series_length,\n",
    "    truncated_length=BPTT_T,\n",
    ")\n",
    "\n",
    "train_data = EchoData(**echo_data_kwargs)\n",
    "train_size = len(train_data)\n",
    "# Number of values in one chunk (the spelling is kept: other cells use this name).\n",
    "total_values_in_one_chunck = batch_size * BPTT_T\n",
    "\n",
    "test_data = EchoData(**echo_data_kwargs)\n",
    "test_size = len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first 20 time steps of the first sequence to see the echo data:\n",
    "for label, series in (\n",
    "    ('(1st input sequence)  x:', train_data.x_batch),\n",
    "    ('(1st target sequence) y:', train_data.y_batch),\n",
    "):\n",
    "    print(label, *series[0, :20], '... ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `batch_size` different sequences are created; show the first 20 steps of each.\n",
    "x_rows = [str(row)[1:-1] + ' ...' for row in train_data.x_batch[:, :20]]\n",
    "print('x_batch:', *x_rows, sep='\\n')\n",
    "print('x_batch size:', train_data.x_batch.shape)\n",
    "print()\n",
    "y_rows = [str(row)[1:-1] + ' ...' for row in train_data.y_batch[:, :20]]\n",
    "print('y_batch:', *y_rows, sep='\\n')\n",
    "print('y_batch size:', train_data.y_batch.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To be fed to an RNN, the data is organized into temporal chunks of size\n",
    "# [batch_size, T, feature_dim]. Show the first input chunk and the first target chunk.\n",
    "first_x_chunk = train_data.x_chunks[0]\n",
    "print('x_chunk:', *first_x_chunk.squeeze(), sep='\\n')\n",
    "print('1st x_chunk size:', first_x_chunk.shape)\n",
    "print()\n",
    "first_y_chunk = train_data.y_chunks[0]\n",
    "print('y_chunk:', *first_y_chunk.squeeze(), sep='\\n')\n",
    "print('1st y_chunk size:', first_y_chunk.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRNN(nn.Module):\n",
    "    \"\"\"Single-layer ReLU RNN followed by a linear read-out applied at every time step.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step of the input.\n",
    "        rnn_hidden_size: size of the recurrent hidden state.\n",
    "        output_size: number of values produced per time step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, rnn_hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.rnn_hidden_size = rnn_hidden_size\n",
    "        self.rnn = torch.nn.RNN(\n",
    "            input_size=input_size,\n",
    "            hidden_size=rnn_hidden_size,\n",
    "            num_layers=1,\n",
    "            nonlinearity='relu',\n",
    "            batch_first=True\n",
    "        )\n",
    "        # The read-out width follows `output_size` instead of a hard-coded 1.\n",
    "        self.linear = torch.nn.Linear(\n",
    "            in_features=rnn_hidden_size,\n",
    "            out_features=output_size\n",
    "        )\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Process one chunk of the stream.\n",
    "\n",
    "        Args:\n",
    "            x: input of shape [batch_size, T, input_size] (the RNN is `batch_first`).\n",
    "            hidden: hidden state returned by the previous call, or None to start\n",
    "                from the RNN's default initial state.\n",
    "\n",
    "        Returns:\n",
    "            A tuple `(output, hidden)`: the per-time-step read-out and the final\n",
    "            hidden state, to be passed to the next call.\n",
    "        \"\"\"\n",
    "        # In order to model the fact that successive batches belong to the same stream\n",
    "        # of data, the hidden state is shared across successive invocations.\n",
    "        x, hidden = self.rnn(x, hidden)\n",
    "        x = self.linear(x)\n",
    "        return x, hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
    "device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "device = torch.device(device_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one training epoch over `train_data`.\n",
    "\n",
    "    Relies on the notebook globals `model`, `optimizer`, `criterion`, `device`,\n",
    "    `train_data`, `train_size` and `total_values_in_one_chunck`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(correct, loss)`: the sum over batches of the per-batch fraction of\n",
    "        correctly predicted values, and the loss value of the last batch.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    accuracy_sum = 0\n",
    "    hidden = None  # new epoch --> fresh hidden state\n",
    "    for step in range(train_size):\n",
    "        inputs, targets = train_data[step]\n",
    "        inputs = torch.from_numpy(inputs).float().to(device)\n",
    "        targets = torch.from_numpy(targets).float().to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # Detach the carried-over hidden state so that back-propagation stops at\n",
    "        # the boundary of the current chunk.\n",
    "        if hidden is not None:\n",
    "            hidden.detach_()\n",
    "        logits, hidden = model(inputs, hidden)\n",
    "        loss = criterion(logits, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        predictions = torch.sigmoid(logits) > 0.5\n",
    "        hits = (predictions == targets.byte()).int().sum().item()\n",
    "        accuracy_sum += hits / total_values_in_one_chunck\n",
    "\n",
    "    return accuracy_sum, loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    \"\"\"Evaluate the model on `test_data` without updating its parameters.\n",
    "\n",
    "    Relies on the notebook globals `model`, `device`, `test_data`, `test_size`\n",
    "    and `total_values_in_one_chunck`.\n",
    "\n",
    "    Returns:\n",
    "        The sum over batches of the per-batch fraction of correctly predicted values.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    accuracy_sum = 0\n",
    "    hidden = None  # new pass --> fresh hidden state\n",
    "    with torch.no_grad():\n",
    "        for step in range(test_size):\n",
    "            inputs, targets = test_data[step]\n",
    "            inputs = torch.from_numpy(inputs).float().to(device)\n",
    "            targets = torch.from_numpy(targets).float().to(device)\n",
    "            logits, hidden = model(inputs, hidden)\n",
    "\n",
    "            predictions = torch.sigmoid(logits) > 0.5\n",
    "            hits = (predictions == targets.byte()).int().sum().item()\n",
    "            accuracy_sum += hits / total_values_in_one_chunck\n",
    "\n",
    "    return accuracy_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_dim = 1  # the series is scalar: one feature per time step\n",
    "h_units = 4\n",
    "\n",
    "# The network echoes its input, so input and output share the same width;\n",
    "# both are tied to `feature_dim` rather than repeating the literal 1.\n",
    "model = SimpleRNN(\n",
    "    input_size=feature_dim,\n",
    "    rnn_hidden_size=h_units,\n",
    "    output_size=feature_dim\n",
    ").to(device)\n",
    "\n",
    "# The model emits raw logits; BCEWithLogitsLoss applies the sigmoid internally.\n",
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "\n",
    "# Train for `n_epochs` epochs; after each one report the loss of its last batch\n",
    "# and the accuracy averaged over all training batches.\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    correct, loss = train()\n",
    "    train_accuracy = 100 * float(correct) / train_size\n",
    "    print(f'Train Epoch: {epoch}/{n_epochs}, loss: {loss:.3f}, accuracy {train_accuracy:.1f}%')\n",
    "\n",
    "# Evaluate once on the test data after training has finished.\n",
    "correct = test()\n",
    "test_accuracy = 100 * float(correct) / test_size\n",
    "print(f'Test accuracy: {test_accuracy:.1f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's try some echoing on a fresh random binary sequence\n",
    "my_input = torch.empty(1, 100, 1).random_(2)\n",
    "hidden = None\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    my_out, _ = model(my_input.to(device), hidden)\n",
    "# The model outputs logits (it is trained with BCEWithLogitsLoss), so use the same\n",
    "# decision rule as train() and test(): sigmoid(logit) > 0.5.\n",
    "my_pred = (torch.sigmoid(my_out) > 0.5).float().cpu()\n",
    "print(my_input.view(1, -1).byte(), my_pred.view(1, -1).byte(), sep='\\n')\n",
    "\n",
    "# Calculate the expected output for our random input: the input delayed by\n",
    "# `echo_step` steps along the time axis, with zeros in the first `echo_step` steps.\n",
    "expected = np.roll(my_input.numpy(), echo_step, axis=1)\n",
    "expected[:, :echo_step] = 0\n",
    "correct = expected == my_pred.numpy()\n",
    "print(np.ndarray.flatten(correct))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
